% [pert_est cv_est med_est a_est] = EMAlgorithm(X,K,iteraciones)
% Algoritmo EM para clasificar las columnas de X en K clases de variables
% gaussianas.

function [pert_est cv_est med_est a_est logL] = MEMAlgorithm(X,Kini,Kmin,iteraciones,med_ini)

[dim N] = size(X);      % Dimensi�n y Cantidad de muestras
if nargin <3
    iteraciones = 40;      % Iteraciones del algoritmo
end

K = Kini;

% Inicializo el algoritmo.
randn('state', 5874);
% Inicializo probabilidades a priori de las clases
a_ini = ones(1,K)/K;
% Incializo las matrices de covarianza
% cv_ini = zeros(dim,dim,K);
cv_ini = repmat(cov(X'),[1 1 K]);

% Inicializo las medias si no fueron ingresadas
if nargin < 4
    orden = randperm(N);
    med_ini = zeros(dim,K);
    med_ini(:) = X(:,orden(1:K));
else

    dist = sum(abs(repmat(X,[1 1 K]) - repmat(permute(med_ini,[1 3 2]),[1 N 1])));
    dist = permute(dist,[3 2 1]);
    [val ind] = min(dist);
    for k = 1:K
        cv_ini(:,:,k) = cov(X(:,ind==k)');
    end
end
% Algoritmo EM
a_est = a_ini;
cv_est = cv_ini;
med_est = med_ini;
inv_cv = zeros(size(cv_est));
const_cv = zeros(K,1);
% pi_k es la probabilidad de obtener la muestra i dado que pertenece a
% la clase k
pi_k = zeros(N,K);
h = waitbar(0,'EM');

% Seguimiento
logL = zeros(1,iteraciones);
variacion = zeros(1,iteraciones);
flag = 1;

while flag

    for iter = 1:iteraciones
        if sum(isnan(med_est(:))) > 0
            break
        end
        % PASO E
        for j = 1:K
            resta = X - repmat(med_est(:,j),1,N);
            % Inversa de la covarianza
            inv_cv(:,:,j) = inv(cv_est(:,:,j));
            % Determinante de la covarianza escalado
            const_cv(j) = 1/(2*pi)^(dim/2) / sqrt(det(cv_est(:,:,j)));
            pi_k(:,j) = (const_cv(j) * exp(-.5 * sum(resta .* (inv_cv(:,:,j)...
                * resta))) * a_est(j))';
        end

        % La log-likelihood es la funci�n que mide la probabilidad de obtener
        % el set de muestras X, se obtiene a partir del c�lculo pi_k*a_est.
        % Promedio de las probabilidades de obtener la muestra i ponderado por
        % la probabilidad de obtener una muestra de cada clase

        % pit es la probabilidad de obtener la muestra i
        pit = sum(pi_k .* repmat(a_est,N,1),2);

        logL(iter) = sum(log2(pit));

        % Probabilidad de cada muestra de pertenecer a cada clase
        pk_i = pi_k .*repmat(a_est,N,1)./ repmat(pit,[1 K 1]);

        % PASO - M
        % Actualizaciones
        a_est = sum(pk_i)/N;

        for j = 1:K
            % La media estimada de la clase k se obtiene del promedio de las muestras
            % ponderadas por la relaci�n entre su probabilidad de pertenecer a dicha
            % clase y la probabilidad de que una muestra pertenezca a la clase. De esta
            % manera se priorizan, por ejemplo, las muestras que pertenecen con mucha
            % seguridad a una clase de poca probabilidad de ocurrencia.
            med_est(:,j) = sum(X.* repmat(pk_i(:,j)',dim,1),2)/a_est(j)/N;
            % Para la covarianza se emplea una metodolog�a similar. Es decir, a las
            % muestras se las pondera por su peso en el set de muestras con pertenencia
            % a la clase.
            resta = X-repmat(med_est(:,j),1,N);
            cv_est(:,:,j) = resta * (repmat(pk_i(:,j),1,dim) .* resta') / N/...
                a_est(j);
        end
               

        % Se establece como l�mite para la iteraci�n al promedio de diferencias en
        % la log-likelihood en las �ltimas 10 iteraciones.
        variacion(:,iter) = sum(abs(diff(logL(:,max(1,iter-10):iter),[],2)),2);
        if variacion(iter) < 0.00001 && iter > 1
            break
        end
        waitbar(iter/iteraciones/(Kini-Kmin+1));
    end

    if K > Kmin
        Jmerge = pk_i'*pk_i;
        Vnormas = sqrt(diag(Jmerge));
        Mnormas = repmat(Vnormas',K,1).*repmat(Vnormas,1,K);
        Jmerge = Jmerge ./ Mnormas;
        %Jmerge = corr(pi_k);
        Jmerge = abs(Jmerge - Jmerge(1)*eye(K));
        %Jmerge = Jmerge .* ((1-a_est)'*(1-a_est)).^2;
        [J I] = max(Jmerge(:));
%         if J(1) < 0.9
%             break
%         end
        cli = ceil(I(1)/K);
        clj = rem(I(1),K);
        if clj == 0
            clj = K;
        end
        % Pierdo un cluster
        K = K - 1;
        % Arreglo las apriori
        a_est = tirar_atras(a_est,[cli,clj]);
        pij = a_est(end-1);
        pik = a_est(end);
        a_est(end-1) =  pij + pik;
        a_est(end) = [];
        % Arreglo las medias
        med_est = tirar_atras(med_est,[cli clj]);
        muj = med_est(:,end-1);
        muk = med_est(:,end);
        med_est(:,end-1) = (pij*muj + pik*muk) / a_est(end);
        med_est(:,end) = [];
        % Arreglo las covarianzas
        cv_est = tirar_atras(cv_est,[cli,clj],3);
        cv_est(:,:,end-1) = (pij*(cv_est(:,:,end-1)+muj*muj') + pik * (cv_est(:,:,end)+muk*muk'))/a_est(end) - med_est(:,end)*med_est(:,end)';
        cv_est(:,:,end) = [];
        % Arreglo los pi_k
        pi_k = zeros(N,K);
    else
        flag = 0;
    end
end

close(h);

[Prob pert_est] = max(pk_i(:,:),[],2);
